"""
---
title: Training a U-Net on Carvana dataset
summary: >
  Code for training a U-Net model on Carvana dataset.
---

# Training [U-Net](index.html)

This trains a [U-Net](index.html) model on the [Carvana dataset](carvana.html).
You can find the download instructions
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).

Save the training images inside the `carvana/train` folder and the masks inside the `carvana/train_masks` folder.

For simplicity, we do not do a training and validation split.
"""

import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn


class Configs(BaseConfigs):
    """
    ## Configurations

    Bundles the hyper-parameters, model, dataset, data loader, loss and optimizer
    used by this experiment, along with the routines that sample from and
    train the model.
    """
    # Device the model is trained on.
    # [`DeviceConfigs`](../helpers/device.html)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # The [U-Net](index.html) model; created in `init`
    model: UNet

    # Number of channels of the input image. $3$ for RGB.
    image_channels: int = 3
    # Number of channels of the predicted mask. $1$ for a binary mask.
    mask_channels: int = 1

    # Number of samples per batch
    batch_size: int = 1
    # Learning rate passed to the optimizer
    learning_rate: float = 2.5e-4

    # How many times `run` calls `train`
    epochs: int = 4

    # The dataset; created in `init`
    dataset: CarvanaDataset
    # Data loader over `dataset`; created in `init`
    data_loader: torch.utils.data.DataLoader

    # Binary cross entropy loss; it is applied to the sigmoid of the logits in `train`
    loss_func = nn.BCELoss()
    # Sigmoid that turns the model logits into per-pixel probabilities
    sigmoid = nn.Sigmoid()

    # Adam optimizer; created in `init`
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize

        Creates the dataset, model, data loader and optimizer (in that order)
        and registers `sample` with the tracker as an image indicator.
        """
        # Locate the images and masks of the [Carvana dataset](carvana.html)
        # under the lab data path
        images_dir = lab.get_data_path() / 'carvana' / 'train'
        masks_dir = lab.get_data_path() / 'carvana' / 'train_masks'
        self.dataset = CarvanaDataset(images_dir, masks_dir)

        # Build the model and move it to the training device
        unet = UNet(self.image_channels, self.mask_channels)
        self.model = unet.to(self.device)

        # Shuffling data loader with pinned memory
        self.data_loader = torch.utils.data.DataLoader(dataset=self.dataset,
                                                       batch_size=self.batch_size,
                                                       shuffle=True,
                                                       pin_memory=True)

        # Adam over all the model parameters
        params = self.model.parameters()
        self.optimizer = torch.optim.Adam(params, lr=self.learning_rate)

        # Tell the tracker that `sample` holds images
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images

        Takes a random image from the dataset, predicts its mask and logs the
        image multiplied by the predicted mask. Gradients are disabled.

        * `idx` is not used in the body. `train` hands this method to `monit.mix`,
          which presumably calls it with a value from the accompanying list.
        """

        # Pick a random item; only the image is needed, the target mask is dropped
        index = np.random.randint(len(self.dataset))
        image, _ = self.dataset[index]
        # Move the image to the device
        image = image.to(self.device)

        # Add a leading batch dimension and run the model
        logits = self.model(image.unsqueeze(0))
        # Turn logits into probabilities
        predicted = self.sigmoid(logits)

        # The prediction can be smaller than the input,
        # so center-crop the image to the prediction's last two dimensions
        height, width = predicted.shape[2], predicted.shape[3]
        image = torchvision.transforms.functional.center_crop(image, [height, width])

        # Log the image weighted by the predicted mask
        tracker.save('sample', image * predicted)

    def train(self):
        """
        ### Train for an epoch

        Goes over the data loader once, taking one optimization step per batch.
        """

        # [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
        # interleaves the training batches with $50$ calls to `sample` per epoch.
        schedule = monit.mix(('Train', self.data_loader), (self.sample, list(range(50))))

        for _, (images, target) in schedule:
            # Advance the global step
            tracker.add_global_step()

            # Move the batch to the device
            images = images.to(self.device)
            target = target.to(self.device)

            # Clear gradients left from the previous step
            self.optimizer.zero_grad()

            # Forward pass; the model outputs logits
            logits = self.model(images)

            # The logits are smaller than the input if the convolutional layers in the U-Net
            # are not padded, so center-crop the target mask to the size of the logits.
            crop_size = [logits.shape[2], logits.shape[3]]
            target = torchvision.transforms.functional.center_crop(target, crop_size)

            # Binary cross entropy between the predicted probabilities and the target mask
            probabilities = self.sigmoid(logits)
            loss = self.loss_func(probabilities, target)

            # Back-propagate
            loss.backward()
            # Update the parameters
            self.optimizer.step()

            # Log the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Calls `train` once per epoch, `epochs` times in total.
        """
        for _ in monit.loop(self.epochs):
            # One pass over the data
            self.train()
            # Start a new line in the console output
            tracker.new_line()
            # NOTE(review): nothing is saved at the end of an epoch here.
            # Confirm whether a checkpoint call is intended at this point.


def main():
    """
    ### Run the experiment

    Creates the experiment, sets up the configurations, registers the model
    with the experiment and runs the training loop.
    """
    # Create the experiment named `unet`
    experiment.create(name='unet')

    # Instantiate the configurations
    conf = Configs()

    # Values in this dictionary override the defaults; it is left empty here
    overrides = {}
    experiment.configs(conf, overrides)

    # Build the dataset, model, data loader and optimizer
    conf.init()

    # Register the model with the experiment for saving and loading
    models = {'model': conf.model}
    experiment.add_pytorch_models(models)

    # Run the training loop inside the started experiment
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    # Run the experiment only when this file is executed as a script, not when it is imported
    main()
